# TGRS Reproducibility Package

Pipeline to reproduce **Teacher-Guided Rank Sensitivity (TGRS)**:
1) collect teacher-guided gradients
2) compute curvature proxy
3) run TGRS allocation and reconstruction
4) evaluate perplexity

This package uses Hugging Face–style model folders and public datasets.

---

## Repository Layout

```
code/
  collect_teacher_grads.py     # teacher-guided gradients
  compute_curvature.py         # diagonal curvature proxy
  tgrs.py                      # scoring, allocation, reconstruction
  eval_perplexity.py           # perplexity evaluation
work/                          # default output directory for intermediate files
models/                        # default output directory for reconstructed models
```

The layout above reflects the defaults. Script names, arguments, and default paths can be changed; if you change them, adjust the commands in this README to match.

---

## Environment

- Python: 3.9 or 3.10
- PyTorch: 2.2+ (CUDA 11.8+ or compatible)
- Transformers: 4.43+
- Datasets: 2.19+
- Tokenizers: 0.15+
- TQDM: 4.66+

### Installation (example with CUDA 11.8 wheels)

```bash
python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cu118
pip install transformers datasets tokenizers tqdm
```

For a different CUDA version, install the matching PyTorch wheel by changing the `cu118` suffix in the index URL (e.g., `cu121`). CPU-only PyTorch is also supported, with longer runtimes.
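
To confirm that an environment meets the minimum versions listed above, a small check along the following lines can help. This is a sketch rather than part of the package: the helper names (`parse_version`, `meets_minimum`, `check_environment`) are illustrative, and the version parser only compares leading numeric components.

```python
from importlib import metadata

# Minimum versions copied from the Environment list (PyPI distribution names).
MINIMUMS = {
    "torch": "2.2",
    "transformers": "4.43",
    "datasets": "2.19",
    "tokenizers": "0.15",
    "tqdm": "4.66",
}


def parse_version(text):
    """Turn '2.2.1+cu118' into (2, 2, 1); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """True if installed >= minimum, comparing numeric components only."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment(minimums=MINIMUMS):
    """Return {package: (installed_version_or_None, ok)} for each requirement."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, meets_minimum(installed, minimum))
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_environment().items():
        print(f"{name:14s} {installed or 'missing':14s} {'OK' if ok else 'CHECK'}")
```

Packages reported as `CHECK` are either missing or older than the listed minimum.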

---

## Prerequisites

- A local Hugging Face–style base model directory (e.g., `meta-llama/Llama-2-7b-hf` mirrored to disk).
- For gated models (such as Llama 2), authenticate with Hugging Face (e.g., `huggingface-cli login`) and download the model to disk before running the pipeline.
- Sufficient disk space for intermediate outputs and the reconstructed model folder.
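
The prerequisites above can be checked before starting a long run. The sketch below assumes the usual Hugging Face folder conventions (a `config.json` plus `*.safetensors` or `*.bin` weight files); `check_model_dir` and `free_gib` are illustrative helpers, not functions shipped with this package.

```python
import shutil
from pathlib import Path


def check_model_dir(model_dir):
    """Return a list of problems; an empty list means the folder looks usable."""
    root = Path(model_dir)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    if not (root / "config.json").is_file():
        problems.append("missing config.json")
    # Weights are usually stored as safetensors or legacy PyTorch .bin shards.
    has_weights = any(root.glob("*.safetensors")) or any(root.glob("*.bin"))
    if not has_weights:
        problems.append("no *.safetensors or *.bin weight files")
    return problems


def free_gib(path="."):
    """Free disk space, in GiB, on the filesystem that contains `path`."""
    return shutil.disk_usage(path).free / 2**30


if __name__ == "__main__":
    print(check_model_dir("/path/to/base/model"))
    print(f"free space: {free_gib('.'):.1f} GiB")
```

How much free space is enough depends on the model size and export mode, so the script only reports the number.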

---

## End-to-End Pipeline

**Order**: (1) teacher gradients → (2) curvature → (3) TGRS → (4) perplexity

### 1) Teacher-guided gradients

Writes one `.pt` tensor per layer and projection tag into `--out_grad_dir`.

```bash
python code/collect_teacher_grads.py \
  --teacher_dir /path/to/base/model \
  --student_dir /path/to/base/model \
  --out_grad_dir ./work/grads_kd \
  --dataset wikitext2 \
  --nsamples 64 \
  --seqlen 512 \
  --blocks q,k,v,o,up,down,gate \
  --batch_size 1 \
  --device cuda
```

**Arguments**
- `--teacher_dir`: path to teacher model folder (HF format).
- `--student_dir`: path to student model folder (HF format). Use the same path as `--teacher_dir`, as in the example above, if the gradients are used only for scoring.
- `--dataset`: dataset name (e.g., `wikitext2`).
- `--nsamples`: number of samples for gradient aggregation.
- `--seqlen`: sequence length per sample.
- `--blocks`: comma-separated projection tags to include.
- `--batch_size`: batch size during collection.
- `--device`: `cuda` or `cpu`.
- `--out_grad_dir`: output directory for gradient tensors.

**Output**
- `work/grads_kd/Lxx_<tag>.pt`: one file per layer and tag, where `xx` is the layer index and `<tag>` is one of the `--blocks` tags.
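
Before moving on, it can be useful to confirm that every layer has a file for every requested tag. The sketch below assumes file names of the form `L<layer index>_<tag>.pt` (for example `L03_q.pt`); the exact zero-padding is an assumption, and `index_tensor_files` and `missing_tags` are illustrative helpers. The same check applies to the curvature directory from Step 2.

```python
import re
from pathlib import Path

# Assumed naming: "L" + layer index + "_" + tag + ".pt", e.g. "L03_q.pt".
NAME_RE = re.compile(r"^L(\d+)_([A-Za-z_]+)\.pt$")


def index_tensor_files(directory):
    """Map each layer index to the set of tags found for that layer."""
    index = {}
    for path in sorted(Path(directory).glob("L*_*.pt")):
        match = NAME_RE.match(path.name)
        if match:
            index.setdefault(int(match.group(1)), set()).add(match.group(2))
    return index


def missing_tags(index, expected_tags):
    """Return {layer: [missing tags]} for layers that lack any expected tag."""
    expected = set(expected_tags)
    return {
        layer: sorted(expected - tags)
        for layer, tags in index.items()
        if expected - tags
    }


if __name__ == "__main__":
    tags = "q,k,v,o,up,down,gate".split(",")
    found = index_tensor_files("./work/grads_kd")
    print(f"{len(found)} layers indexed")
    print("incomplete layers:", missing_tags(found, tags))
```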

---

### 2) Curvature proxy (diagonal)

Computes a diagonal curvature proxy (empirical Fisher if `--emp_fisher` is set).

```bash
python code/compute_curvature.py \
  --model_dir /path/to/base/model \
  --out_hess_dir ./work/hess_diag \
  --dataset wikitext2 \
  --nsamples 64 \
  --seqlen 512 \
  --blocks q,k,v,o,up,down,gate \
  --batch_size 1 \
  --emp_fisher \
  --device cuda
```

**Arguments**
- `--model_dir`: path to model folder (HF format).
- `--out_hess_dir`: output directory for curvature tensors.
- `--emp_fisher`: use average of squared gradients as diagonal proxy.
- Remaining flags mirror Step 1 where applicable.

**Output**
- `work/hess_diag/Lxx_<tag>.pt` per layer/tag.

This step is optional: Step 3 runs without curvature if `--hess_dir` is omitted or `--beta` is set to `0.0`.
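
The `--emp_fisher` option is described above as an average of squared gradients. The framework-free sketch below illustrates that computation on plain Python lists; it assumes the average is taken over per-sample gradients, which may differ from how `compute_curvature.py` batches its inputs.

```python
def empirical_fisher_diag(per_sample_grads):
    """Element-wise mean of squared gradients across samples.

    `per_sample_grads` is a list of equal-length gradient vectors,
    one per sample (or per batch, depending on how gradients are collected).
    """
    if not per_sample_grads:
        raise ValueError("need at least one gradient sample")
    size = len(per_sample_grads[0])
    diag = [0.0] * size
    for grad in per_sample_grads:
        if len(grad) != size:
            raise ValueError("all gradient vectors must have the same length")
        for i, g in enumerate(grad):
            diag[i] += g * g  # accumulate squared gradient per parameter
    n = len(per_sample_grads)
    return [value / n for value in diag]


if __name__ == "__main__":
    grads = [[1.0, 2.0], [3.0, -2.0]]
    print(empirical_fisher_diag(grads))  # [5.0, 4.0]
```

The result has one non-negative entry per parameter, which is why it can serve as a diagonal curvature proxy.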

---

### 3) TGRS allocation and reconstruction

Consumes gradients (`--grad_dir`) and optional curvature (`--hess_dir`). Produces a reconstructed HF folder at `--out_dir`.

```bash
python -u code/tgrs.py \
  --base_model_dir /path/to/base/model \
  --grad_dir ./work/grads_kd \
  --hess_dir ./work/hess_diag \
  --out_dir ./models/tgrs_fold \
  --blocks all \
  --budget_avg_bits 3.8 \
  --q_bits 4 \
  --lr_bits 4 \
  --scale_bits 16 \
  --alpha 1.0 --beta 0.0 --gamma 1.0 \
  --allow_zero_rank \
  --rank_cap_map "q:16,k:8,v:12,o:16,up:8,down:8" \
  --score_device cuda --reconstruct_device cuda
```

**Key arguments**
- `--base_model_dir`: HF base model folder.
- `--grad_dir`: directory from Step 1.
- `--hess_dir`: directory from Step 2 (optional).
- `--out_dir`: output model folder (HF format).
- `--blocks`: which projections to include; `all` or `all+io` (adds `embed` and `lm_head` if supported).
- `--budget_avg_bits`: average bits-per-parameter target (Eq. (11) in the paper).
- `--q_bits`: bit-width for quantized backbone integers (`b_l`).
- `--lr_bits`: bit-width for low-rank factor integers (`b_r`).
- `--scale_bits`: bit-width for scales (`b_s`).
- `--alpha`, `--beta`, `--gamma`: non-negative aggregation coefficients for the three signals (gradients, curvature, and spectral energy, respectively). If curvature is unavailable, set `--beta 0.0`.
- `--allow_zero_rank`: permit rank 0 for some layers.
- `--rank_cap_map`: per-tag caps for selected singular directions.
- `--q_bits_map` (optional): per-tag backbone bit-widths, e.g., `"up:2,down:2,q:4,k:4,v:4,o:4"`.
- `--score_device`, `--reconstruct_device`: devices for scoring and reconstruction.
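
The `--rank_cap_map` and `--q_bits_map` values are comma-separated `tag:value` pairs. A minimal parser for that format is sketched below; `parse_tag_map` is an illustrative helper, and the actual parsing in `tgrs.py` may be stricter or more lenient.

```python
def parse_tag_map(text, value_type=int):
    """Parse 'q:16,k:8' into {'q': 16, 'k': 8}."""
    result = {}
    for item in text.split(","):
        item = item.strip()
        if not item:
            continue  # tolerate trailing commas
        tag, sep, value = item.partition(":")
        if not sep or not tag.strip():
            raise ValueError(f"expected 'tag:value', got {item!r}")
        result[tag.strip()] = value_type(value.strip())
    return result


if __name__ == "__main__":
    caps = parse_tag_map("q:16,k:8,v:12,o:16,up:8,down:8")
    collected = "q,k,v,o,up,down,gate".split(",")
    print(caps)
    print("tags without a cap:", sorted(set(collected) - set(caps)))
```

Run against the example values in this README, the script reports that `gate`, which is collected in Steps 1 and 2, has no entry in the example `--rank_cap_map`. How uncapped tags are handled is not specified here, so it is worth confirming in `quant_meta.json`.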

**Outputs**
- `models/tgrs_fold/`: HF model folder with reconstructed weights.
- `models/tgrs_fold/quant_meta.json`: allocation summary (e.g., realized average bits, selected ranks, per-tag settings).

**Notes**
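- If `--hess_dir` is not provided, the curvature term is skipped. Setting `--beta 0.0`, as in the example command above, also disables it even when `--hess_dir` is given.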
- If any tag in `--q_bits_map` uses `2` bits, the script can apply a randomized block-orthogonal preconditioner internally.
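
To illustrate why `--beta 0.0` removes the curvature term, the sketch below combines three per-direction signals with non-negative coefficients. The additive form is an assumption made for illustration only; the scoring rule actually used by `tgrs.py` is defined in the paper and may normalize or combine the signals differently. `combined_score` is a hypothetical name.

```python
def combined_score(grad, curvature, spectral, alpha=1.0, beta=0.0, gamma=1.0):
    """Hypothetical additive aggregation of three non-negative signals."""
    if min(alpha, beta, gamma) < 0:
        raise ValueError("coefficients must be non-negative")
    if curvature is None:
        # No curvature available: only valid when its weight is zero.
        if beta != 0.0:
            raise ValueError("curvature is missing; set beta to 0.0")
        curvature = 0.0
    return alpha * grad + beta * curvature + gamma * spectral


if __name__ == "__main__":
    # With beta = 0.0 the curvature value has no effect on the score.
    print(combined_score(2.0, 10.0, 3.0, beta=0.0))
    print(combined_score(2.0, None, 3.0, beta=0.0))
```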

---

### 4) Perplexity evaluation (WikiText-2)

Evaluates perplexity of the reconstructed HF model folder.

```bash
python code/eval_perplexity.py \
  --model_dir ./models/tgrs_fold \
  --dataset wikitext2 \
  --nsamples 128 \
  --seqlen 512 \
  --batch_size 1 \
  --device cuda
```

**Output**
- Prints `ppl=...` to stdout.
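
For reference, perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below shows that definition and a small parser for the `ppl=...` line; both helpers are illustrative, and the number format printed by `eval_perplexity.py` is assumed to be a plain decimal.

```python
import math
import re


def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def parse_ppl(stdout_text):
    """Return the last 'ppl=<number>' value found in the text, or None."""
    pattern = r"ppl=\s*([0-9]+(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?)"
    matches = re.findall(pattern, stdout_text)
    return float(matches[-1]) if matches else None


if __name__ == "__main__":
    # A model that assigns probability 1/4 to every token has perplexity 4.
    print(perplexity([math.log(4)] * 10))
    print(parse_ppl("evaluating...\nppl=5.68\n"))
```

Because the mean is taken per token, results are only comparable across runs that use the same `--seqlen`, `--nsamples`, and tokenizer.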

---

## Export Modes (optional)

`code/tgrs.py` writes different artifacts depending on `--deploy_mode`:

- `fold` (default): writes reconstructed, quantized weights into an HF folder.
- `pack`: writes a packed artifact (indices/scales, low-rank factors); requires `safetensors`.
- `export`: writes only low-rank adapters (`A/B`) and a manifest.
- `jetson`: writes INT8 adapters and static per-tensor scales for TensorRT compatibility.

**Examples**

```bash
# pack
python -u code/tgrs.py \
  --base_model_dir /path/to/model \
  --grad_dir ./work/grads_kd \
  --hess_dir ./work/hess_diag \
  --out_dir ./models/q_pack \
  --blocks all \
  --q_bits 4 --q_bits_map "up:2,down:2" \
  --lr_bits 4 \
  --scale_bits 16 \
  --budget_avg_bits 3.8 \
  --deploy_mode pack

# export
python -u code/tgrs.py \
  --base_model_dir /path/to/model \
  --grad_dir ./work/grads_kd \
  --hess_dir ./work/hess_diag \
  --out_dir ./models/q_export \
  --blocks all \
  --q_bits 4 \
  --lr_bits 4 \
  --scale_bits 16 \
  --budget_avg_bits 3.8 \
  --deploy_mode export

# jetson
python -u code/tgrs.py \
  --base_model_dir /path/to/model \
  --grad_dir ./work/grads_kd \
  --hess_dir ./work/hess_diag \
  --out_dir ./models/q_jetson \
  --blocks all \
  --q_bits 4 \
  --lr_bits 8 \
  --scale_bits 16 \
  --budget_avg_bits 4.0 \
  --deploy_mode jetson --jetson_static_scale
```

**Artifacts by mode**
- `fold`: HF folder + `quant_meta.json`.
- `pack`: `model.quant.safetensors` + `quant_meta.json`.
- `export`: `adapters/*.pt` + `adapters_manifest.json` + `quant_meta.json`.
- `jetson`: `adapters/jetson_*.pt` + `jetson_manifest.json` + `quant_meta.json`.
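
The artifact list above can be turned into a quick post-run check. This sketch only tests for the file names listed in this README; for `fold` it checks `quant_meta.json` alone and does not validate the HF weights. `missing_artifacts` is an illustrative helper.

```python
from pathlib import Path

# File patterns taken from the "Artifacts by mode" list.
EXPECTED = {
    "fold": ["quant_meta.json"],
    "pack": ["model.quant.safetensors", "quant_meta.json"],
    "export": ["adapters/*.pt", "adapters_manifest.json", "quant_meta.json"],
    "jetson": ["adapters/jetson_*.pt", "jetson_manifest.json", "quant_meta.json"],
}


def missing_artifacts(out_dir, mode):
    """Return the expected patterns that match no file under out_dir."""
    if mode not in EXPECTED:
        raise ValueError(f"unknown deploy mode: {mode!r}")
    root = Path(out_dir)
    return [pattern for pattern in EXPECTED[mode] if not any(root.glob(pattern))]


if __name__ == "__main__":
    print(missing_artifacts("./models/q_pack", "pack"))
```

An empty list means every expected artifact for that mode was found.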

---

## Reproducibility Notes

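- Fix the random seeds (Python, NumPy, PyTorch) in each script if run-to-run repeatability is required. Some GPU kernels are nondeterministic, so bitwise-identical results may still not be guaranteed.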
- Report `--budget_avg_bits`, `--q_bits`, `--lr_bits`, `--scale_bits`, and any `--q_bits_map` used.
- Record the lists of blocks/tags and per-tag rank caps (`--rank_cap_map`).
- Record `--nsamples`, `--seqlen`, and dataset versions for Steps 1, 2, and 4.
- Hardware, driver, and library versions can affect timing results.
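
The notes above can be supported with a small helper that fixes seeds and writes the run settings to a JSON file. This is a sketch under the assumption that the scripts do not already do this; `set_seed` and `write_run_record` are illustrative names, and NumPy and PyTorch are seeded only if they are installed.

```python
import json
import platform
import random


def set_seed(seed):
    """Seed Python's RNG and, when available, NumPy and PyTorch."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and CUDA generators
    except ImportError:
        pass


def write_run_record(path, args, seed):
    """Save the arguments, seed, and basic platform details as JSON."""
    record = {
        "seed": seed,
        "python": platform.python_version(),
        "platform": platform.platform(),
        "args": dict(args),
    }
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(record, handle, indent=2, sort_keys=True)
    return record


if __name__ == "__main__":
    set_seed(0)
    write_run_record(
        "run_record.json",
        {"budget_avg_bits": 3.8, "q_bits": 4, "lr_bits": 4, "scale_bits": 16,
         "nsamples": 64, "seqlen": 512, "dataset": "wikitext2"},
        seed=0,
    )
```

Keeping the record next to `quant_meta.json` makes it easier to match a reported perplexity to the settings that produced it.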